{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZTO7hf-zM-np",
    "outputId": "e8944765-badd-4a6c-c111-0ca802c1f23b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7e59d435d310>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.utils import clip_grad_norm_\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch.optim as optim\n",
    "from datasets import load_dataset\n",
    "from transformers import pipeline\n",
    "\n",
    "torch.manual_seed(12046)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "dJhQvyIYM-nr"
   },
   "outputs": [],
   "source": [
    "# 一些超参数\n",
    "learning_rate = 5e-5\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "gamma = 1.0\n",
    "lambda_ = 0.95\n",
    "kl_ctl_value = 0.2\n",
    "cliprange = 0.2\n",
    "vf_coef = 0.1\n",
    "# 经过mini_batch_size步后，更新旧模型\n",
    "mini_batch_size = 20\n",
    "grad_clip = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "Bh967HO6M-ns"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "ZZWmZrRlM-ns"
   },
   "outputs": [],
   "source": [
    "def prepare_input(data):\n",
    "    data['input_ids'] = [tokenizer.encode(data['text'])[:8]]\n",
    "    return data\n",
    "\n",
    "datasets = load_dataset('imdb', split='train[:500]')\n",
    "datasets = datasets.filter(lambda x: len(x['text']) > 20)\n",
    "tokenized = datasets.map(prepare_input, remove_columns=datasets.column_names)\n",
    "tokenized.set_format(type='torch', device=device)\n",
    "example = tokenized[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "cH6wexbYM-nw",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "class A2CLLM(nn.Module):\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.actor = model\n",
    "        # 值函数估计头\n",
    "        self.critic = nn.Linear(model.base_model.embed_dim, 1, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        向前传播，为了使代码易懂，该函数只支持单条文本的计算\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，文本，形状为(1, T)\n",
    "        返回\n",
    "        ----\n",
    "        logits ：torch.FloatTensor，logits，形状为(1, T, vs)\n",
    "        values ：torch.FloatTensor，值函数，形状为(1, T)\n",
    "        '''\n",
    "        _res = self.actor(input_ids=x, output_hidden_states=True)\n",
    "        logits = _res.logits\n",
    "        emb = _res.hidden_states[-1]\n",
    "        values = self.critic(emb).squeeze(-1)\n",
    "        return logits, values\n",
    "\n",
    "    def generate(self, idx, max_new_tokens=20):\n",
    "        '''\n",
    "        生成文本\n",
    "        '''\n",
    "        model = self.actor\n",
    "        return model.generate(idx, max_new_tokens=max_new_tokens,\n",
    "                             pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "model = A2CLLM(AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "EO_cOnnCM-nw"
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig, PeftModel\n",
    "\n",
    "def init_peft_model(model):\n",
    "    config = LoraConfig(\n",
    "        r=1,\n",
    "        lora_alpha=8,\n",
    "        target_modules=['c_attn'],\n",
    "        fan_in_fan_out=True,\n",
    "        lora_dropout=0.1,\n",
    "        bias='none',\n",
    "        modules_to_save=['critic'])\n",
    "    return PeftModel(model, config, adapter_name='lora_ppo')\n",
    "\n",
    "# 增加LoRA适配器\n",
    "model = init_peft_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tv2HXFQwM-nw",
    "outputId": "fbd5f1e8-7cfd-4148-df83-1cc6059958b5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logits torch.Size([1, 20, 50257])\n",
      "lnp torch.Size([1, 20])\n",
      "values torch.Size([1, 20])\n"
     ]
    }
   ],
   "source": [
    "def get_forward_result(model, input_ids, response):\n",
    "    '''\n",
    "    记录向前传播的结果，分别是logits，lnp和值函数\n",
    "    为了使代码易懂，该函数只支持单条文本的计算\n",
    "    '''\n",
    "    # 记录背景文本的长度\n",
    "    _, lens = input_ids.shape\n",
    "    logits, values = model(response)\n",
    "    # 计算交叉熵的时候，需要注意logits和标签的对应关系\n",
    "    lnp = -F.cross_entropy(logits[:, :-1, :].transpose(-2, -1), response[:, 1:], reduction='none')\n",
    "    # 只记录针对生成文本的结果，其中L表示生成文本的长度\n",
    "    res = {\n",
    "        # 最后一个位置的logits没有作用\n",
    "        'logits': logits[:, lens-1:-1, :],  # (1, L, vs)\n",
    "        'lnp': lnp[:, lens-1:],             # (1, L)\n",
    "        'values': values[:, lens:]          # (1, L)\n",
    "    }\n",
    "    return res\n",
    "\n",
    "\n",
    "input_ids = example['input_ids']\n",
    "response = model.generate(input_ids)\n",
    "\n",
    "# 验证get_forward_result计算结果的形状是准确的\n",
    "example_re = get_forward_result(model, input_ids, response)\n",
    "for k, v in example_re.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Rkl8rTvBcOQX",
    "outputId": "01dadb70-33c5-4a79-8991-199b24d80d81"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.3356, -0.3501, -0.6011, -0.4132,  1.0261,  0.8811, -0.3165,  0.4929,\n",
      "         -0.9196, -0.3321, -0.2723, -0.1996, -0.6541,  0.1892,  0.6956,  0.3488,\n",
      "          0.2956,  0.3583,  0.2754,  0.5844,  0.7313,  0.1374,  0.5127, -0.1030,\n",
      "          0.5666, -0.0081,  0.3219, -0.0353]], device='cuda:0',\n",
      "       grad_fn=<SubBackward0>)\n",
      "tensor([[ 0.0418,  0.5579,  0.5273,  1.0549,  0.5402,  0.1473,  0.3205,  0.0311,\n",
      "          0.6900, -0.2323,  0.1526,  0.4450,  0.1746,  0.6160, -0.2214, -0.1989,\n",
      "          0.1022,  0.2701, -0.0173, -0.0539, -0.1477,  0.0678, -0.0153, -0.6429,\n",
      "         -0.3822, -0.4266, -0.2184, -0.4352]], device='cuda:0',\n",
      "       grad_fn=<SubBackward0>)\n",
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0.]], device='cuda:0', grad_fn=<SubBackward0>)\n"
     ]
    }
   ],
   "source": [
    "def turn_on_train_mode(model, target):\n",
    "    '''\n",
    "    只将模型中的特定组件设置为训练模式\n",
    "    '''\n",
    "    for name, module in model.named_modules():\n",
    "        if name.split('.')[-1] in target:\n",
    "            module.train()\n",
    "    return model\n",
    "\n",
    "def _test_turn_on_train_mode():\n",
    "    '''\n",
    "    测试turn_on_train_mode是否正确\n",
    "    '''\n",
    "    test_model = A2CLLM(\n",
    "        AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)\n",
    "    config = LoraConfig(\n",
    "        r=1,\n",
    "        lora_alpha=8,\n",
    "        target_modules=['c_attn'],\n",
    "        fan_in_fan_out=True,\n",
    "        lora_dropout=0.1,\n",
    "        bias='none',\n",
    "        init_lora_weights=False)\n",
    "    test_model = PeftModel(test_model, config, adapter_name='lora_ppo')\n",
    "    # 模型处于训练模式，由于随机失活的原因，每次运算的结果都不相同\n",
    "    test_model.train()\n",
    "    v1 = test_model(response)[1]\n",
    "    v2 = test_model(response)[1]\n",
    "    print(v1 - v2)\n",
    "\n",
    "    test_model.eval()\n",
    "    # 只将LoRA换至训练模式，由于LoRA里的随机失活，每次运算的结果也不相同\n",
    "    turn_on_train_mode(test_model, ['c_attn'])\n",
    "    v1 = test_model(response)[1]\n",
    "    v2 = test_model(response)[1]\n",
    "    print(v1 - v2)\n",
    "\n",
    "    test_model.eval()\n",
    "    turn_on_train_mode(test_model, ['c_attn'])\n",
    "    # 禁用LoRA之后，运算结果会相同\n",
    "    with test_model.disable_adapter():\n",
    "        v1 = test_model(response)[1]\n",
    "        v2 = test_model(response)[1]\n",
    "        print(v1 - v2)\n",
    "\n",
    "_test_turn_on_train_mode()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wW8OBSb6M-nx",
    "outputId": "a2faedc5-e036-4c3e-b29a-0f4acb924448"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9959])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class RewardModel(nn.Module):\n",
    "\n",
    "    def __init__(self, tokenizer):\n",
    "        '''\n",
    "        评分模型\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.model = pipeline(\"sentiment-analysis\", model='lvwerra/distilbert-imdb')\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        向前传播，为了使代码易懂，该函数只支持单条文本的计算\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，文本，形状为(1, T)\n",
    "        返回\n",
    "        ----\n",
    "        re ：torch.FloatTensor，评分，形状为(1)\n",
    "        '''\n",
    "        re = []\n",
    "        x = [self.tokenizer.decode(i) for i in x]\n",
    "        # 此处的x等于背景文本+生成文本，因此得到的scores稍有不妥\n",
    "        # 更准确的做法是只对生成文本进行评分\n",
    "        scores = self.model(x)\n",
    "        for s in scores:\n",
    "            # 将POSITIVE的概率视为评分\n",
    "            if s['label'] == 'POSITIVE':\n",
    "                re.append(s['score'])\n",
    "            else:\n",
    "                re.append(1 - s['score'])\n",
    "        return torch.tensor(re)\n",
    "\n",
    "r_model = RewardModel(tokenizer).to(device)\n",
    "r_model(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1rGQcB0PM-nx",
    "outputId": "2dfc1902-72f5-43af-9234-b224ebd52959"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 20])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compute_rewards(r_model, response, lnp, ref_lnp):\n",
    "    '''\n",
    "    定义游戏奖励\n",
    "    为了使代码易懂，该函数只支持单条文本的计算\n",
    "    '''\n",
    "    # scores的形状为(1), lnp的形状为(1, L), ref_lnp的形状为(1, L)\n",
    "    # r_model：评分模型，response：模型生成的回答\n",
    "    # lnp：新/旧模型的概率对数，ref_lnp：参考模型的概率对数\n",
    "    scores = r_model(response)\n",
    "    rewards = []\n",
    "    for score, lnprob, ref_lnprob in zip(scores, lnp, ref_lnp):\n",
    "        kl = lnprob - ref_lnprob     # (   L)\n",
    "        # kl_ctl_value是调节KL penalty的系数，大于0\n",
    "        reward = -kl_ctl_value * kl  # (   L)\n",
    "        # 游戏奖励等于模型评分 + KL penalty\n",
    "        reward[-1] += score          # (   L)\n",
    "        rewards.append(reward)\n",
    "    return torch.stack(rewards)      # (1, L)\n",
    "\n",
    "# 得到参考模型的结果\n",
    "with torch.no_grad():\n",
    "    with model.disable_adapter():\n",
    "        ref_example_re = get_forward_result(model, input_ids, response)\n",
    "\n",
    "rewards = compute_rewards(r_model, response, example_re['lnp'], ref_example_re['lnp'])\n",
    "rewards.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "huNDSgciM-nx"
   },
   "outputs": [],
   "source": [
    "class GAE:\n",
    "\n",
    "    def __init__(self, gamma, lambda_):\n",
    "        self.gamma = gamma\n",
    "        self.lambda_ = lambda_\n",
    "\n",
    "    def __call__(self, rewards, values):\n",
    "        # 优势函数\n",
    "        advantages = []\n",
    "        last_advantage = 0\n",
    "        vt_next = 0\n",
    "        for r, vt in zip(reversed(rewards), reversed(values)):\n",
    "            delta = r + self.gamma * vt_next - vt\n",
    "            last_advantage = delta + self.gamma * self.lambda_ * last_advantage\n",
    "            advantages.insert(0, last_advantage)\n",
    "            vt_next = vt\n",
    "\n",
    "        return torch.stack(advantages)\n",
    "\n",
    "gae = GAE(gamma, lambda_)\n",
    "advantages = gae(rewards, example_re['values'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MU3Sz6iwM-ny",
    "outputId": "c104d182-1547-4596-8092-4d1d117cbe95"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.2746, device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compute_loss(old_lnp, lnp, vpred, advantages):\n",
    "    '''\n",
    "    定义模型损失\n",
    "    为了使代码易懂，该函数只支持单条文本的计算\n",
    "    '''\n",
    "    # old_lnp：旧模型的概率对数，形状为(1, L)\n",
    "    # lnp：新/旧模型的概率对数，形状为(1, L)\n",
    "    # vpred：值函数，形状为(1, L)\n",
    "    # advantages：优势函数，形状为(1, L)\n",
    "    # 值函数损失\n",
    "    vf_loss = -advantages * vpred\n",
    "    # 策略损失\n",
    "    ratio = torch.exp(lnp - old_lnp)\n",
    "    pg_losses = -advantages * ratio\n",
    "    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)\n",
    "    pg_loss = torch.max(pg_losses, pg_losses2)\n",
    "    # 整体损失\n",
    "    loss = pg_loss.mean() + vf_coef * vf_loss.mean()\n",
    "    return loss\n",
    "\n",
    "compute_loss(example_re['lnp'], example_re['lnp'], example_re['values'], advantages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "v_yIwG4tM-nz",
    "outputId": "adb33a65-0b81-4955-f4fb-85be56dc33f3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[   40, 26399,   314,  3001,   327, 47269, 20958,    12]],\n",
       "        device='cuda:0'),\n",
       " tensor([[    1,    40,  1703, 44269,    25, 12550,     1,   318]],\n",
       "        device='cuda:0')]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def play_game(model, r_model, gae, data):\n",
    "    model.eval()\n",
    "    # 分别是背景文本，回复，向前传播结果和优势函数\n",
    "    all_input_ids, all_response, all_res, all_advantages = [], [], [], []\n",
    "    for input_ids in data['input_ids']:\n",
    "        all_input_ids.append(input_ids)\n",
    "        # 生成评论\n",
    "        response = model.generate(input_ids)\n",
    "        all_response.append(response)\n",
    "        with torch.no_grad():\n",
    "            # 记录旧模型数据\n",
    "            res = get_forward_result(model, input_ids, response)\n",
    "            all_res.append(res)\n",
    "            # 记录参考模型数据\n",
    "            with model.disable_adapter():\n",
    "                ref_res = get_forward_result(model, input_ids, response)\n",
    "            rewards = compute_rewards(r_model, response, res['lnp'], ref_res['lnp'])\n",
    "            all_advantages.append(gae(rewards, res['values']))\n",
    "    # 只将LoRA适配器切换至训练模式\n",
    "    turn_on_train_mode(model, ['c_attn'])\n",
    "    return all_input_ids, all_response, all_res, all_advantages\n",
    "\n",
    "play_game(model, r_model, gae, tokenized[:2])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "sPjg11nEM-nz",
    "outputId": "bcc0c350-c4bd-4485-85bc-2c718bc6a020"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 0.5244841426610947, 'ref_score': 0.5244841426610947}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def estimate_rewards(r_model, model, all_input_ids):\n",
    "    '''\n",
    "    预估模型评分\n",
    "    '''\n",
    "    re = {}\n",
    "    # 将模型切换至评估模式\n",
    "    model.eval()\n",
    "    for input_ids in all_input_ids:\n",
    "        # 生成文本\n",
    "        response = model.generate(input_ids)\n",
    "        # 记录评分\n",
    "        re['score'] = re.get('score', 0) + r_model(response).item()\n",
    "        # 记录参考模型的评分\n",
    "        with model.disable_adapter():\n",
    "            response = model.generate(input_ids)\n",
    "            re['ref_score'] = re.get('ref_score', 0) + r_model(response).item()\n",
    "    re['score'] /= len(all_input_ids)\n",
    "    re['ref_score'] /= len(all_input_ids)\n",
    "    # 只将LoRA适配器切换至训练模式\n",
    "    turn_on_train_mode(model, ['c_attn'])\n",
    "    return re\n",
    "\n",
    "estimate_rewards(r_model, model, tokenized[:20]['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mizmt8YrM-n0",
    "outputId": "f6614c34-3780-4655-821a-061daa007119"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step    0: score 0.5412, ref_score 0.5085\n",
      "step    1: score 0.5412, ref_score 0.5085\n",
      "step    2: score 0.5085, ref_score 0.5085\n",
      "step    3: score 0.5412, ref_score 0.5085\n",
      "step    4: score 0.5180, ref_score 0.5085\n",
      "step    5: score 0.5182, ref_score 0.5085\n",
      "step    6: score 0.4743, ref_score 0.5085\n",
      "step    7: score 0.4743, ref_score 0.5085\n",
      "step    8: score 0.4741, ref_score 0.5085\n",
      "step    9: score 0.4741, ref_score 0.5085\n",
      "step   10: score 0.4725, ref_score 0.5085\n",
      "step   11: score 0.5210, ref_score 0.5085\n",
      "step   12: score 0.5225, ref_score 0.5085\n",
      "step   13: score 0.5168, ref_score 0.5085\n",
      "step   14: score 0.5184, ref_score 0.5085\n",
      "step   15: score 0.5135, ref_score 0.5085\n",
      "step   16: score 0.5147, ref_score 0.5085\n",
      "step   17: score 0.5129, ref_score 0.5085\n",
      "step   18: score 0.6062, ref_score 0.5085\n",
      "step   19: score 0.6182, ref_score 0.5085\n",
      "step   20: score 0.6737, ref_score 0.5085\n",
      "step   21: score 0.6730, ref_score 0.5085\n",
      "step   22: score 0.6731, ref_score 0.5085\n",
      "step   23: score 0.6724, ref_score 0.5085\n"
     ]
    }
   ],
   "source": [
    "steps = datasets.num_rows // mini_batch_size\n",
    "optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for s in range(steps-1):\n",
    "    data = tokenized[s * mini_batch_size: (s + 1) * mini_batch_size]\n",
    "    # 进行游戏，收集数据。play_game返回的数据都是无法计算梯度的\n",
    "    # 在play_game中，会基于model生成参考模型\n",
    "    input_ids, response, old_res, advantages = play_game(model, r_model, gae, data)\n",
    "    # 循环完成之后，才用新模型替换旧模型\n",
    "    for _ids, _resp, _old_res, _ad in zip(input_ids, response, old_res, advantages):\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        # 收集新模型的数据，model_res里面的数据可以计算梯度\n",
    "        model_res = get_forward_result(model, _ids, _resp)\n",
    "        loss = compute_loss(_old_res['lnp'], model_res['lnp'], model_res['values'], _ad)\n",
    "        loss.backward()\n",
    "        # 梯度裁剪\n",
    "        clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        optimizer.step()\n",
    "    # 将最后一个批次数据作为测试集\n",
    "    res = estimate_rewards(r_model, model, tokenized[-mini_batch_size:]['input_ids'])\n",
    "    print(f'step {s:>4}: score {res[\"score\"]:.4f}, ref_score {res[\"ref_score\"]:.4f}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
